/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "client_base.h"

#include <stdbool.h>

#include <securec.h>
#include <tee_service_public.h>
#include <tee_sharemem_ops.h>

#include "logger.h"
#include "se_base_services_defines.h"

// Upper bound, in bytes, shared by the request payload length, the caller's reply
// capacity and the size of the shared-memory buffer allocated for each request.
#define MAX_DATA_LENGTH 512U

/*
 * Allocates a zero-filled shared-memory buffer of maxLen bytes for SERVICE_UUID and,
 * when a payload is supplied, copies dataLen bytes of it to the start of the buffer.
 *
 * data    - optional payload; may be NULL, in which case the buffer stays zero-filled.
 * dataLen - number of payload bytes to copy; must not exceed maxLen.
 * maxLen  - buffer size to allocate; must be in the range 1..MAX_DATA_LENGTH.
 *
 * Returns the buffer on success, or NULL on invalid arguments, allocation failure,
 * or a failed wipe/copy. A non-NULL result must be released with
 * tee_free_sharemem(memory, maxLen) by the caller; on every failure path the buffer
 * is released here before returning.
 */
static uint8_t *CreateShareMemory(const uint8_t *data, uint32_t dataLen, uint32_t maxLen)
{
    // data could be NULL
    if (dataLen > maxLen || maxLen == 0 || maxLen > MAX_DATA_LENGTH) {
        LOG_ERROR("CreateShareMemory failed with dataLen=%u, maxLen=%u", dataLen, maxLen);
        return NULL;
    }

    TEE_UUID uuid = SERVICE_UUID;
    uint8_t *memory = tee_alloc_sharemem_aux(&uuid, maxLen);
    if (memory == NULL) {
        // Report the size that was actually requested from the allocator.
        LOG_ERROR("tee_alloc_sharemem_aux failed with size %u", maxLen);
        return NULL;
    }

    if (memset_s(memory, maxLen, 0, maxLen) != EOK) {
        LOG_ERROR("memset_s failed with size %u", maxLen);
        tee_free_sharemem(memory, maxLen);
        return NULL;
    }

    // Nothing to copy: hand back the zero-filled buffer as is.
    if (data == NULL || dataLen == 0) {
        return memory;
    }

    if (memcpy_s(memory, maxLen, data, dataLen) != EOK) {
        LOG_ERROR("memcpy_s failed with size %u", dataLen);
        // Release the buffer, otherwise it leaks because the caller only sees NULL.
        tee_free_sharemem(memory, maxLen);
        return NULL;
    }
    return memory;
}

/*
 * Wipes and releases a shared-memory buffer.
 *
 * memory - buffer to release; NULL is accepted and ignored.
 * maxLen - size passed to both the wipe and tee_free_sharemem.
 *
 * A failed wipe is only logged; the buffer is released regardless.
 */
static void DestroyShareMemory(uint8_t *memory, uint32_t maxLen)
{
    if (memory != NULL) {
        // Clear the contents before the buffer is handed back.
        const bool wiped = (memset_s(memory, maxLen, 0, maxLen) == EOK);
        if (!wiped) {
            LOG_ERROR("memset_s failed with size %u", maxLen);
        }

        tee_free_sharemem(memory, maxLen);
    }
}

/*
 * Copies the reply payload out of the shared buffer into the caller's reply buffer.
 *
 * rsp          - IPC response; the payload length is read from rsp->msg.args_data.arg4.
 * buffer       - shared buffer holding the payload.
 * bufferMaxLen - size of the shared buffer; a larger reported payload length is rejected.
 * reply        - destination buffer; may be NULL only together with replyLen.
 * replyLen     - in: capacity of reply; out: number of bytes copied.
 *
 * Returns true when nothing had to be copied (the caller passed neither reply nor
 * replyLen, or the reported payload length is zero) or when the copy succeeded.
 * Returns false on NULL rsp/buffer, an oversized payload, a missing reply or replyLen
 * while a payload is present, or a failed memcpy_s. *replyLen is written only after a
 * successful copy.
 */
static bool FillUpReplyBuffer(const tee_service_ipc_msg_rsp *rsp, const uint8_t *buffer, uint32_t bufferMaxLen,
    uint8_t *reply, uint32_t *replyLen)
{
    if (rsp == NULL || buffer == NULL) {
        LOG_ERROR("TryToFillReplyBuffer failed, rsp or buffer is null");
        return false;
    }

    // The caller did not ask for any reply data at all.
    const bool replyRequested = (reply != NULL) || (replyLen != NULL);
    if (!replyRequested) {
        return true;
    }

    const uint32_t payloadLen = (uint32_t)rsp->msg.args_data.arg4;
    if (payloadLen == 0) {
        return true;
    }

    if (payloadLen > bufferMaxLen) {
        LOG_ERROR("TryToFillReplyBuffer failed with bufferLen %u, bufferMaxLen %u", payloadLen, bufferMaxLen);
        return false;
    }

    // A payload is present, so both the destination and its capacity are required.
    if (!(reply != NULL && replyLen != NULL)) {
        LOG_ERROR("TryToFillReplyBuffer failed, reply or replyLen is null");
        return false;
    }

    const bool copied = (memcpy_s(reply, *replyLen, buffer, payloadLen) == EOK);
    if (!copied) {
        LOG_ERROR("TryToFillReplyBuffer memcpy_s failed with bufferLen %u, replyLen %u", payloadLen, *replyLen);
        return false;
    }

    *replyLen = payloadLen;
    return true;
}

/*
 * Sends one command to SERVICE_NAME through a temporary shared-memory buffer and
 * copies the reply payload, if any, back to the caller.
 *
 * cmd      - command id, passed as arg0 of the IPC message.
 * data     - optional request payload; may be NULL.
 * dataLen  - request payload length; must not exceed MAX_DATA_LENGTH.
 * reply    - optional destination for the reply payload; may be NULL.
 * replyLen - optional; in: capacity of reply (at most MAX_DATA_LENGTH),
 *            out: reply length, written by FillUpReplyBuffer.
 *
 * Returns INVALID_PARA_ERR_SIZE when a length exceeds MAX_DATA_LENGTH, MEM_ALLOC_ERR
 * when the shared buffer cannot be created, the ret field of the IPC response when it
 * is not SUCCESS, IPC_RES_FILL_UP_ERR when the reply cannot be copied out, and
 * otherwise the (successful) ret field of the IPC response.
 */
ResultCode SendRequestToServer(uint32_t cmd, const uint8_t *data, uint32_t dataLen, uint8_t *reply, uint32_t *replyLen)
{
    // data could be NULL, only its length is bounded here
    if (dataLen > MAX_DATA_LENGTH) {
        LOG_ERROR("SendRequestToServer invalid input data or data len");
        return INVALID_PARA_ERR_SIZE;
    }

    // reply and replyLen could be NULL; a supplied capacity is bounded as well
    if (replyLen != NULL && *replyLen > MAX_DATA_LENGTH) {
        LOG_ERROR("invalid input reply or replyLen");
        return INVALID_PARA_ERR_SIZE;
    }

    LOG_DEBUG("SendRequestToServer invoke begin, cmd = 0x%x", cmd);

    TEE_Result result = TEE_FAIL;
    uint8_t *shared = CreateShareMemory(data, dataLen, MAX_DATA_LENGTH);
    if (shared == NULL) {
        result = MEM_ALLOC_ERR;
        LOG_ERROR("SendRequestToServer tee_alloc_sharemem_aux failed with size %u", dataLen);
    } else {
        tee_service_ipc_msg request = {{0}};
        request.args_data.arg0 = cmd;                           // the cmd id
        request.args_data.arg1 = MAX_DATA_LENGTH;               // the max shared buff length
        request.args_data.arg2 = (uint64_t)dataLen;             // cmd buff length
        request.args_data.arg3 = (uint64_t)(uintptr_t)(shared); // cmd buff
        request.args_data.arg4 = 0;                             // rsp length (will be filled by server)

        // The response starts out as TEE_FAIL, which is what is reported if ret is never updated.
        tee_service_ipc_msg_rsp response = {TEE_FAIL, {0}};
        tee_common_ipc_proc_cmd(SERVICE_NAME, 0, &request, 0, &response);

        result = response.ret;
        if (result != SUCCESS) {
            LOG_ERROR("SendRequestToServer tee_common_ipc_proc_cmd failed with error 0x%x", result);
        } else if (!FillUpReplyBuffer(&response, shared, MAX_DATA_LENGTH, reply, replyLen)) {
            LOG_ERROR("SendRequestToServer FillUpReplyBuffer failed with error");
            result = IPC_RES_FILL_UP_ERR;
        }
    }

    // Runs on every path; DestroyShareMemory ignores a NULL buffer.
    DestroyShareMemory(shared, MAX_DATA_LENGTH);

    LOG_DEBUG("SendRequestToServer invoke finish, cmd = 0x%x, ret = 0x%x", cmd, result);
    return result;
}
